{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 网络参数的初始化\n",
    "\n",
    "[![](https://gitee.com/mindspore/docs/raw/master/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/master/docs/programming_guide/source_zh_cn/initializer.ipynb)&emsp;[![](https://gitee.com/mindspore/docs/raw/master/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/master/programming_guide/mindspore_initializer.ipynb)&emsp;[![](https://gitee.com/mindspore/docs/raw/master/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9vYnMuZHVhbHN0YWNrLmNuLW5vcnRoLTQubXlodWF3ZWljbG91ZC5jb20vbWluZHNwb3JlLXdlYnNpdGUvbm90ZWJvb2svbW9kZWxhcnRzL3Byb2dyYW1taW5nX2d1aWRlL21pbmRzcG9yZV9pbml0aWFsaXplci5pcHluYg==&imagename=MindSpore1.1.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 概述\n",
    "\n",
    "MindSpore提供了权重初始化模块，用户可以通过封装算子和initializer方法来调用字符串、Initializer子类或自定义Tensor等方式完成对网络参数进行初始化。Initializer类是MindSpore中用于进行初始化的基本数据结构，其子类包含了几种不同类型的数据分布（Zero，One，XavierUniform，HeUniform，HeNormal，Constant，Uniform，Normal，TruncatedNormal）。下面针对封装算子和initializer方法两种参数初始化模式进行详细介绍。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用封装算子对参数初始化                                                                                 \n",
    "MindSpore提供了多种参数初始化的方式，并在部分算子中封装了参数初始化的功能。本节将介绍带有参数初始化功能的算子对参数进行初始化的方法，以`Conv2d`算子为例，分别介绍以字符串，`Initializer`子类和自定义`Tensor`等方式对网络中的参数进行初始化，以下代码示例中均以`Initializer`的子类`Normal`为例，代码示例中`Normal`均可替换成`Initializer`子类中任何一个。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 字符串                                                                                               \n",
    "使用字符串对网络参数进行初始化，字符串的内容需要与`Initializer`子类的名称保持一致，使用字符串方式进行初始化将使用`Initializer`子类中的默认参数，例如使用字符串`Normal`等同于使用`Initializer`的子类`Normal()`，代码样例如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[[ 3.10382620e-02  4.38603461e-02  4.38603461e-02 ...  4.38603461e-02\n",
      "     4.38603461e-02  1.38719045e-02]\n",
      "   [ 3.26051228e-02  3.54298912e-02  3.54298912e-02 ...  3.54298912e-02\n",
      "     3.54298912e-02 -5.54019120e-03]\n",
      "   [ 3.26051228e-02  3.54298912e-02  3.54298912e-02 ...  3.54298912e-02\n",
      "     3.54298912e-02 -5.54019120e-03]\n",
      "   ...\n",
      "   [ 3.26051228e-02  3.54298912e-02  3.54298912e-02 ...  3.54298912e-02\n",
      "     3.54298912e-02 -5.54019120e-03]\n",
      "   [ 3.26051228e-02  3.54298912e-02  3.54298912e-02 ...  3.54298912e-02\n",
      "     3.54298912e-02 -5.54019120e-03]\n",
      "   [ 9.66199022e-03  1.24104535e-02  1.24104535e-02 ...  1.24104535e-02\n",
      "     1.24104535e-02 -1.38977719e-02]]\n",
      "\n",
      "  ...\n",
      "\n",
      "  [[ 3.98553275e-02 -1.35465711e-03 -1.35465711e-03 ... -1.35465711e-03\n",
      "    -1.35465711e-03 -1.00310734e-02]\n",
      "   [ 4.38403059e-03 -3.60766202e-02 -3.60766202e-02 ... -3.60766202e-02\n",
      "    -3.60766202e-02 -2.95619294e-02]\n",
      "   [ 4.38403059e-03 -3.60766202e-02 -3.60766202e-02 ... -3.60766202e-02\n",
      "    -3.60766202e-02 -2.95619294e-02]\n",
      "   ...\n",
      "   [ 4.38403059e-03 -3.60766202e-02 -3.60766202e-02 ... -3.60766202e-02\n",
      "    -3.60766202e-02 -2.95619294e-02]\n",
      "   [ 4.38403059e-03 -3.60766202e-02 -3.60766202e-02 ... -3.60766202e-02\n",
      "    -3.60766202e-02 -2.95619294e-02]\n",
      "   [ 1.33139016e-02  6.74417242e-05  6.74417242e-05 ...  6.74417242e-05\n",
      "     6.74417242e-05 -2.27325838e-02]]]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore.nn as nn\n",
    "from mindspore import Tensor\n",
    "from mindspore.common import set_seed\n",
    "\n",
    "set_seed(1)\n",
    "\n",
    "input_data = Tensor(np.ones([1, 3, 16, 50], dtype=np.float32))\n",
    "net = nn.Conv2d(3, 64, 3, weight_init='Normal')\n",
    "output = net(input_data)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initializer子类                                                                                   \n",
    "使用`Initializer`子类对网络参数进行初始化，与使用字符串对参数进行初始化的效果类似，不同的是使用字符串进行参数初始化是使用`Initializer`子类的默认参数，如要使用`Initializer`子类中的参数，就必须使用`Initializer`子类的方式对参数进行初始化，以`Normal(0.2)`为例，代码样例如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[[ 6.2076533e-01  8.7720710e-01  8.7720710e-01 ...  8.7720710e-01\n",
      "     8.7720710e-01  2.7743810e-01]\n",
      "   [ 6.5210247e-01  7.0859784e-01  7.0859784e-01 ...  7.0859784e-01\n",
      "     7.0859784e-01 -1.1080378e-01]\n",
      "   [ 6.5210247e-01  7.0859784e-01  7.0859784e-01 ...  7.0859784e-01\n",
      "     7.0859784e-01 -1.1080378e-01]\n",
      "   ...\n",
      "   [ 6.5210247e-01  7.0859784e-01  7.0859784e-01 ...  7.0859784e-01\n",
      "     7.0859784e-01 -1.1080378e-01]\n",
      "   [ 6.5210247e-01  7.0859784e-01  7.0859784e-01 ...  7.0859784e-01\n",
      "     7.0859784e-01 -1.1080378e-01]\n",
      "   [ 1.9323981e-01  2.4820906e-01  2.4820906e-01 ...  2.4820906e-01\n",
      "     2.4820906e-01 -2.7795550e-01]]\n",
      "\n",
      "  ...\n",
      "\n",
      "  [[ 7.9710668e-01 -2.7093157e-02 -2.7093157e-02 ... -2.7093157e-02\n",
      "    -2.7093157e-02 -2.0062150e-01]\n",
      "   [ 8.7680638e-02 -7.2153252e-01 -7.2153252e-01 ... -7.2153252e-01\n",
      "    -7.2153252e-01 -5.9123868e-01]\n",
      "   [ 8.7680638e-02 -7.2153252e-01 -7.2153252e-01 ... -7.2153252e-01\n",
      "    -7.2153252e-01 -5.9123868e-01]\n",
      "   ...\n",
      "   [ 8.7680638e-02 -7.2153252e-01 -7.2153252e-01 ... -7.2153252e-01\n",
      "    -7.2153252e-01 -5.9123868e-01]\n",
      "   [ 8.7680638e-02 -7.2153252e-01 -7.2153252e-01 ... -7.2153252e-01\n",
      "    -7.2153252e-01 -5.9123868e-01]\n",
      "   [ 2.6627803e-01  1.3488382e-03  1.3488382e-03 ...  1.3488382e-03\n",
      "     1.3488382e-03 -4.5465171e-01]]]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore.nn as nn\n",
    "from mindspore import Tensor\n",
    "from mindspore.common import set_seed\n",
    "from mindspore.common.initializer import Normal\n",
    "\n",
    "set_seed(1)\n",
    "\n",
    "input_data = Tensor(np.ones([1, 3, 16, 50], dtype=np.float32))\n",
    "net = nn.Conv2d(3, 64, 3, weight_init=Normal(0.2))\n",
    "output = net(input_data)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 自定义的Tensor                                                                \n",
    "除上述两种初始化方法外，当网络要使用MindSpore中没有的数据类型对参数进行初始化，用户可以通过自定义`Tensor`的方式来对参数进行初始化，代码样例如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[[12. 18. 18. ... 18. 18. 12.]\n",
      "   [18. 27. 27. ... 27. 27. 18.]\n",
      "   [18. 27. 27. ... 27. 27. 18.]\n",
      "   ...\n",
      "   [18. 27. 27. ... 27. 27. 18.]\n",
      "   [18. 27. 27. ... 27. 27. 18.]\n",
      "   [12. 18. 18. ... 18. 18. 12.]]\n",
      "\n",
      "  ...\n",
      "\n",
      "  [[12. 18. 18. ... 18. 18. 12.]\n",
      "   [18. 27. 27. ... 27. 27. 18.]\n",
      "   [18. 27. 27. ... 27. 27. 18.]\n",
      "   ...\n",
      "   [18. 27. 27. ... 27. 27. 18.]\n",
      "   [18. 27. 27. ... 27. 27. 18.]\n",
      "   [12. 18. 18. ... 18. 18. 12.]]]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore.nn as nn\n",
    "from mindspore import Tensor\n",
    "from mindspore import dtype as mstype\n",
    "\n",
    "weight = Tensor(np.ones([64, 3, 3, 3]), dtype=mstype.float32)\n",
    "input_data = Tensor(np.ones([1, 3, 16, 50], dtype=np.float32))\n",
    "net = nn.Conv2d(3, 64, 3, weight_init=weight)\n",
    "output = net(input_data)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用initializer方法对参数初始化\n",
    "\n",
    "在上述代码样例中，给出了如何在网络中进行参数初始化的方法，如在网络中使用nn层封装`Conv2d`算子，参数`weight_init`作为要初始化的数据类型传入`Conv2d`算子，算子会在初始化时通过调用`Parameter`类，进而调用封装在`Parameter`类中的`initializer`方法来完成对参数的初始化。然而有一些算子并没有像`Conv2d`那样在内部对参数初始化的功能进行封装，如`Conv3d`算子的权重就是作为参数传入`Conv3d`算子，此时就需要手动的定义权重的初始化。\n",
    "\n",
    "当对参数进行初始化时，可以使用`initializer`方法调用`Initializer`子类中不同的数据类型来对参数进行初始化，进而产生不同类型的数据。\n",
    "\n",
    "使用initializer进行参数初始化时，支持传入的参数有`init`、`shape`、`dtype`：\n",
    "\n",
    "- `init`：支持传入`Tensor`、 `str`、 `Initializer的子类`。\n",
    "\n",
    "- `shape`：支持传入`list`、 `tuple`、 `int`。\n",
    "\n",
    "- `dtype`：支持传入`mindspore.dtype`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### init参数为Tensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "代码样例如下："
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-02-03T02:59:50.340750Z",
     "start_time": "2021-02-03T02:59:49.571048Z"
    }
   },
   "source": [
    "```python\n",
    "import numpy as np\n",
    "from mindspore import Tensor\n",
    "from mindspore import dtype as mstype\n",
    "from mindspore.common import set_seed\n",
    "from mindspore.common.initializer import initializer\n",
    "from mindspore.ops.operations import nn_ops as nps\n",
    "\n",
    "set_seed(1)\n",
    "\n",
    "input_data = Tensor(np.ones([16, 3, 10, 32, 32]), dtype=mstype.float32)\n",
    "weight_init = Tensor(np.ones([32, 3, 4, 3, 3]), dtype=mstype.float32)\n",
    "weight = initializer(weight_init, shape=[32, 3, 4, 3, 3])\n",
    "conv3d = nps.Conv3D(out_channel=32, kernel_size=(4, 3, 3))\n",
    "output = conv3d(input_data, weight)\n",
    "print(output)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "输出如下：\n",
    "\n",
    "```text\n",
    "[[[[[108 108 108 ... 108 108 108]\n",
    "    [108 108 108 ... 108 108 108]\n",
    "    [108 108 108 ... 108 108 108]\n",
    "    ...\n",
    "    [108 108 108 ... 108 108 108]\n",
    "    [108 108 108 ... 108 108 108]\n",
    "    [108 108 108 ... 108 108 108]]\n",
    "    ...\n",
    "   [[108 108 108 ... 108 108 108]\n",
    "    [108 108 108 ... 108 108 108]\n",
    "    [108 108 108 ... 108 108 108]\n",
    "    ...\n",
    "    [108 108 108 ... 108 108 108]\n",
    "    [108 108 108 ... 108 108 108]\n",
    "    [108 108 108 ... 108 108 108]]]]]\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### init参数为str"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "代码样例如下："
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```python\n",
    "import numpy as np\n",
    "from mindspore import Tensor\n",
    "from mindspore import dtype as mstype\n",
    "from mindspore.common import set_seed\n",
    "from mindspore.common.initializer import initializer\n",
    "from mindspore.ops.operations import nn_ops as nps\n",
    "\n",
    "set_seed(1)\n",
    "\n",
    "input_data = Tensor(np.ones([16, 3, 10, 32, 32]), dtype=mstype.float32)\n",
    "weight = initializer('Normal', shape=[32, 3, 4, 3, 3], dtype=mstype.float32)\n",
    "conv3d = nps.Conv3D(out_channel=32, kernel_size=(4, 3, 3))\n",
    "output = conv3d(input_data, weight)\n",
    "print(output)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "输出如下：\n",
    "\n",
    "```text\n",
    "[[[[[0 0 0 ... 0 0 0]\n",
    "    [0 0 0 ... 0 0 0]\n",
    "    [0 0 0 ... 0 0 0]]\n",
    "    ...\n",
    "    [0 0 0 ... 0 0 0]\n",
    "    [0 0 0 ... 0 0 0]\n",
    "    [0 0 0 ... 0 0 0]]\n",
    "    ...\n",
    "   [[0 0 0 ... 0 0 0]\n",
    "    [0 0 0 ... 0 0 0]\n",
    "    [0 0 0 ... 0 0 0]]\n",
    "    ...\n",
    "    [0 0 0 ... 0 0 0]\n",
    "    [0 0 0 ... 0 0 0]\n",
    "    [0 0 0 ... 0 0 0]]]]]\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### init参数为Initializer子类"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "代码样例如下："
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```python\n",
    "import numpy as np\n",
    "from mindspore import Tensor\n",
    "from mindspore import dtype as mstype\n",
    "from mindspore.common import set_seed\n",
    "from mindspore.ops.operations import nn_ops as nps\n",
    "from mindspore.common.initializer import Normal, initializer\n",
    "\n",
    "set_seed(1)\n",
    "\n",
    "input_data = Tensor(np.ones([16, 3, 10, 32, 32]), dtype=mstype.float32)\n",
    "weight = initializer(Normal(0.2), shape=[32, 3, 4, 3, 3], dtype=mstype.float32)\n",
    "conv3d = nps.Conv3D(out_channel=32, kernel_size=(4, 3, 3))\n",
    "output = conv3d(input_data, weight)\n",
    "print(output)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```text\n",
    "[[[[[0 0 0 ... 0 0 0]\n",
    "    [0 0 0 ... 0 0 0]\n",
    "    [0 0 0 ... 0 0 0]]\n",
    "    ...\n",
    "    [0 0 0 ... 0 0 0]\n",
    "    [0 0 0 ... 0 0 0]\n",
    "    [0 0 0 ... 0 0 0]]\n",
    "    ...\n",
    "   [[0 0 0 ... 0 0 0]\n",
    "    [0 0 0 ... 0 0 0]\n",
    "    [0 0 0 ... 0 0 0]]\n",
    "    ...\n",
    "    [0 0 0 ... 0 0 0]\n",
    "    [0 0 0 ... 0 0 0]\n",
    "    [0 0 0 ... 0 0 0]]]]]\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 在Parameter中的应用"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "代码样例如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-0.3305102  1.0412874  2.0412874  3.0412874]\n",
      " [ 4.0412874  4.9479127  5.9479127  6.9479127]\n",
      " [ 7.947912   9.063009  10.063009  11.063009 ]\n",
      " [12.063009  13.536987  14.536987  14.857441 ]\n",
      " [15.751231  17.073082  17.808317  19.364822 ]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from mindspore import dtype as mstype\n",
    "from mindspore.common import set_seed\n",
    "from mindspore.ops import operations as ops\n",
    "from mindspore import Tensor, Parameter, context\n",
    "from mindspore.common.initializer import Normal, initializer\n",
    "\n",
    "set_seed(1)\n",
    "\n",
    "weight1 = Parameter(initializer('Normal', [5, 4], mstype.float32), name=\"w1\")\n",
    "weight2 = Parameter(initializer(Normal(0.2), [5, 4], mstype.float32), name=\"w2\")\n",
    "input_data = Tensor(np.arange(20).reshape(5, 4), dtype=mstype.float32)\n",
    "net = ops.Add()\n",
    "output = net(input_data, weight1)\n",
    "output = net(output, weight2)\n",
    "print(output)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
